"""
Code for the Fig. 1 (b, c, d) plots.

Yuechuan Lin 
2022-03-30
Cornell University
"""
import matplotlib.pyplot as plt 
import numpy as np 
import pandas as pd 
import matplotlib 
from scipy import signal 
import matplotlib.ticker as tck
import matplotlib.ticker 
from matplotlib import gridspec
import scipy.io as sio
import h5py
import scipy.signal as ssignal
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,
                               AutoMinorLocator)

# Uncomment to print the location of the active matplotlibrc file:
# print(matplotlib.matplotlib_fname())

# Global text style for every panel: sans-serif Helvetica at size 15 with
# normal weight, rendered through LaTeX (usetex), and fonts embedded as
# TrueType (Type 42) in PDF and PostScript output.
font = {'weight': 'normal', 'size': 15}
matplotlib.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Helvetica'],
})
matplotlib.rc('font', **font)
matplotlib.rc('text', usetex=True)
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'mathtext.default': 'regular',
})

# After installing new system fonts, refresh Matplotlib's font cache once with:
# matplotlib.font_manager._rebuild()
def set_share_axes(axs, target=None, sharex=False, sharey=False):
    """Share the x and/or y axis of every Axes in ``axs`` with ``target``.

    Parameters
    ----------
    axs : numpy.ndarray of matplotlib Axes
        The Axes to link, visited through ``axs.flat``.
    target : matplotlib Axes, optional
        The Axes every other Axes is linked to; defaults to ``axs.flat[0]``.
    sharex, sharey : bool, optional
        Whether to share the x axis and/or the y axis.

    When ``axs`` has more than one dimension, sharing x also hides the x tick
    labels and offset text on every row except the last, and sharing y hides
    the y tick labels and offset text on every column except the first.
    """
    if target is None:
        target = axs.flat[0]

    def _link(ax, name):
        # Link axis ``name`` ('x' or 'y') of ``ax`` to the same axis of ``target``.
        share = getattr(ax, 'share' + name, None)
        if share is not None:
            # Public API (Axes.sharex / Axes.sharey): besides joining the
            # sharing group it also syncs limits, scale and tickers.
            try:
                share(target)
                return
            except ValueError:
                # ``ax`` already shares this axis with a different Axes;
                # merge the two sharing groups with a grouper join instead.
                pass
        # Private grouper fallback. Newer matplotlib keeps the groupers in the
        # ``_shared_axes`` dict, older releases in ``_shared_<name>_axes``.
        groupers = getattr(target, '_shared_axes', None)
        if groupers is not None:
            grouper = groupers[name]
        else:
            grouper = getattr(target, '_shared_%s_axes' % name)
        grouper.join(target, ax)

    for ax in axs.flat:
        if ax is target:
            continue  # target is trivially shared with itself
        if sharex:
            _link(ax, 'x')
        if sharey:
            _link(ax, 'y')

    # Turn off x tick labels and offset text for all but the bottom row.
    if sharex and axs.ndim > 1:
        for ax in axs[:-1, :].flat:
            ax.xaxis.set_tick_params(which='both', labelbottom=False, labeltop=False)
            ax.xaxis.offsetText.set_visible(False)
    # Turn off y tick labels and offset text for all but the left-most column.
    if sharey and axs.ndim > 1:
        for ax in axs[:, 1:].flat:
            ax.yaxis.set_tick_params(which='both', labelleft=False, labelright=False)
            ax.yaxis.offsetText.set_visible(False)


# Figure layout: a 7.2 x 4 inch canvas divided into a 2 x 3 grid. Cell (0, 0)
# stays empty, and the force-vs-depth panel spans both rows of the last column.
fig = plt.figure(figsize=(7.2, 4), constrained_layout=False)
fig.tight_layout(pad=0.0)  # NOTE(review): called before any Axes exist
fig.subplots_adjust(hspace=0.4, wspace=0.4)

grid_shape = (2, 3)
ax0 = plt.subplot2grid(grid_shape, (0, 1))             # (a) PSF image
ax1 = plt.subplot2grid(grid_shape, (1, 0))             # (b) PSF profiles and fits
ax2 = plt.subplot2grid(grid_shape, (0, 2), rowspan=2)  # (c) force vs. depth z
ax3 = plt.subplot2grid(grid_shape, (1, 1))             # (d) force vs. lateral position (GLMT vs. measured)


## Load the beam-profile (PSF) data used by panels (a) and (b).
# scipy.io.loadmat returns a MATLAB struct as a structured array; [0, 0]
# selects its single record, whose fields are then read by name.
PSFdata_dict_tmp = sio.loadmat('./Fig1_data/BeamProfile.mat')
PSFdata_dict = PSFdata_dict_tmp['BeamProfile']
PSFdata = PSFdata_dict[0, 0]

imdata = PSFdata['imdata']
xLong, yLong, yfitLong = PSFdata['xLong'], PSFdata['yLong'], PSFdata['fityLong']
xShort, yShort, yfitShort = PSFdata['xShort'], PSFdata['yShort'], PSFdata['fityShort']
longGof, shortGof = PSFdata['longGof'], PSFdata['shortGof']

# Normalize the PSF image to a peak of 1, then pad its shorter side so it
# becomes a centred longAx x longAx square image. The padding value is the
# median of a strip of pixels near the edge of the original image
# (presumably background — confirm).
imdata = imdata / np.amax(imdata)
rows, cols = np.shape(imdata)[:2]
longAx = max(rows, cols)
shortAx = min(rows, cols)
cutLength = (longAx - shortAx) // 2  # offset that centres the short side

if rows > cols:
    # Tall image: pad left/right, using columns 1-9 of the ORIGINAL image
    # for the fill value. (Previously this took the median of the all-ones
    # canvas, which always filled the padding with the peak value 1.)
    imdata_ex = np.full((longAx, longAx), np.median(imdata[:, 1:10]))
    imdata_ex[:, cutLength:cutLength + shortAx] = imdata
else:
    # Wide or square image: pad top/bottom, using rows 1-9 of the image.
    imdata_ex = np.full((longAx, longAx), np.median(imdata[1:10, :]))
    imdata_ex[cutLength:cutLength + shortAx, :] = imdata

# Panel (a): padded PSF image with a 20 um scale bar and an inset colorbar.
im = ax0.imshow(imdata_ex, cmap='hot')
# Pin the colour range to the normalized [0, 1] intensity on the image itself.
# Colorbar.set_clim was deprecated in matplotlib 3.3 and later removed;
# ScalarMappable.set_clim is its replacement.
im.set_clim(0, 1)

pixelSize = 5.2/(16.1980*2)  # micrometres per image pixel
scaleBarLength = 20  # micro meter
scaleBar_x = 60 + np.linspace(0, scaleBarLength/pixelSize, 100)  # bar span in pixels
ax0.plot(scaleBar_x, np.full_like(scaleBar_x, 1030), linewidth=2, color='w')
ax0.text(10, 1050, r'20$\mu$m', fontsize=8, color='w')
ax0.set_xlabel(r'x-axis (pixels)', labelpad=-3)
ax0.set_ylabel(r'y-axis (pixels)', labelpad=4)
ax0.set_xlim([0, longAx])
ax0.set_ylim([0, longAx])
# Both axes are hidden, so the axis labels above are not drawn.
ax0.get_xaxis().set_visible(False)
ax0.get_yaxis().set_visible(False)
ax0.text(5, 33, r'(a)', fontsize=16, color='w')

# Inset colorbar: 50% wide, 5% tall, lower-right corner, white ticks on top.
axins0 = inset_axes(ax0,
                    width="50%",  # width = 50% of parent_bbox width
                    height="5%",  # height : 5%
                    loc='lower right')
# Colormap and limits come from ``im``. ``cmap``/``fraction`` have no effect
# when a mappable and an explicit ``cax`` are passed, so they are omitted.
cbar = fig.colorbar(im, cax=axins0, orientation='horizontal')
axins0.xaxis.set_ticks_position("top")
# tick_params persists when ticks are regenerated, unlike styling the current
# tick-label Text objects directly.
cbar.ax.tick_params(axis='x', color='white', labelcolor='white', labelsize=8)
cbar.outline.set_edgecolor('white')


# Panel (b): measured PSF profiles (dots) and their fits (lines) along the
# long axis (orange, bottom x-axis) and the short axis (blue, top x-axis of a
# twiny Axes).
PSFlong_raw = yLong
PSFlong_fit = yfitLong
PSFlong_rawAxis = xLong
PSFlong_fitAxis = xLong

PSFshort_raw = yShort
PSFshort_fit = yfitShort
PSFshort_rawAxis = xShort
PSFshort_fitAxis = xShort

# NOTE(review): long-axis data are shifted by +4.5 (um), presumably to centre
# the profile on 0 — confirm.
ax1.scatter(PSFlong_rawAxis+4.5, PSFlong_raw, 15, color='orange', marker='o', edgecolors='face', linewidth=0, alpha=0.8, label='long axis')
ax1.plot(PSFlong_fitAxis+4.5, PSFlong_fit, color='C5', linewidth=1.5, label='long axis fit', linestyle="-")
ax1.set_xlabel(r'Axis ($\mu$m)', labelpad=-2)
ax1.set_yticks(list(ax1.get_yticks()) + [0.5])  # add a 0.5 tick to the automatic ones
ax1.set_ylim([-0.01, 1.02])
ax1.set_xlim([-80, 80])
ax1.set_ylabel(r'Amplitude', labelpad=0)
ax1.text(-76, 0.85, r'(b)', fontsize=16)
ax1.tick_params(axis='x', colors='orange', labelcolor='orange', which='both', direction='in')
ax1.xaxis.set_minor_locator(AutoMinorLocator(2))
ax1.yaxis.set_minor_locator(AutoMinorLocator(2))

ax11 = ax1.twiny()
# These are the short-axis data, so they carry the 'short axis' label.
ax11.scatter(PSFshort_rawAxis, PSFshort_raw, 15, color='blue', marker='o', edgecolors='face', linewidth=0, alpha=0.8, label='short axis')
ax11.plot(PSFshort_fitAxis, PSFshort_fit, color='C5', linestyle="-", linewidth=1.5, label='short axis fit')
ax11.set_xlim([-10, 10])

ax11.spines['bottom'].set_color('orange')
ax11.spines['top'].set_color('blue')
ax11.xaxis.set_minor_locator(AutoMinorLocator(2))
ax11.tick_params(axis='x', colors='blue', labelcolor='blue', direction='in', which='both', top=True, bottom=False)



# Panel (c) data: depth-dependent force profile measured under M-Mode.
# NOTE(review): this first Zz only feeds the shape print below; Zz is
# reassigned from FradZcenter.mat a few lines further down.
zprofile = sio.loadmat('./Fig1_data/Frad_depthProfile_zaxis.mat')
Zz = zprofile['Fradzaxis']
print(Zz.shape)

FradZData_tmp = sio.loadmat('./Fig1_data/FradZcenter.mat')
FradZData_dict = FradZData_tmp['FradZcenter']
# Each field of the MATLAB struct is a 1x1 cell; [0, 0] unwraps it.
MModez, MModeFradz, medForce_list, Zz = (
    FradZData_dict[key][0, 0]
    for key in ('zData', 'FradData', 'FradMedian', 'zMedian')
)
# Column 0 of FradMedian is the median force and column 1 its spread
# (named stdForce); both are smoothed with a 5-point median filter.
medForce, stdForce = (
    ssignal.medfilt(medForce_list[:, column], kernel_size=5)
    for column in (0, 1)
)

print(medForce.shape)
print(Zz.shape)
 
# Panel (c): force vs. depth z. Faint blue dots are the individual
# measurements; black dots are the median and red dots the median +/- stdForce.
ax2.scatter(MModeFradz, MModez, 20, color='#1f77b4', marker='o', edgecolors='face', linewidth=0, alpha=0.05, label='Frad', rasterized=True)
ax2.scatter(medForce, Zz, 1, color='k', marker='o', edgecolors='face', linewidth=0, label='median')
ax2.scatter(stdForce+medForce, Zz, 0.05, color='#d62728', marker='o', edgecolors='face', label='std')
ax2.scatter(medForce-stdForce, Zz, 0.05, color='#d62728', marker='o', edgecolors='face')
ax2.set_xlim([-0.1, 3.5])
ax2.set_ylim([-180, 180])
ax2.set_xlabel(r'Force (pN)', labelpad=-4)
ax2.set_ylabel(r'$z$ ($\mu$m)', labelpad=-18)
leg = ax2.legend(loc=1, fontsize=10, markerscale=2, labelspacing=0.1, columnspacing=0.5, borderpad=0.2, fancybox=True, handletextpad=0.5, borderaxespad=0.2)
# The plotted markers are tiny and/or nearly transparent, so give the legend
# entries a visible alpha and size. ``Legend.legend_handles`` replaced
# ``legendHandles`` (deprecated in matplotlib 3.7, removed in 3.9); fall back
# to the old name on releases that predate the new one.
legend_handles = getattr(leg, 'legend_handles', None)
if legend_handles is None:
    legend_handles = leg.legendHandles
for lh in legend_handles:
    lh.set_alpha(0.7)
    lh.set_sizes([6.0])

ax2.text(2.8, -170, r'(c)', fontsize=16)
ax2.set_xticks((0, 1.5, 3.0))
ax2.set_yticks((-150, -100, -50, 0.0, 50, 100, 150))



## Panel (d) data: GLMT simulation and measured force across the long axis
## of the light sheet.
simData_raw = sio.loadmat('./Fig1_data/simRaw.mat')
simData = simData_raw['sim']

MeaData_tmp = sio.loadmat('./Fig1_data/FradX.mat')
MeaData_dict = MeaData_tmp['FradX']
MeaData_x = MeaData_dict['xData'][0, 0]
MeaData_force_all = MeaData_dict['FradData'][0, 0]
# Column 0 holds the measured force and column 1 its uncertainty.
MeaData_force, MeaData_delta = MeaData_force_all[:, 0], MeaData_force_all[:, 1]
for _measured in (MeaData_delta, MeaData_x, MeaData_force):
    print(_measured.shape)
# Panel (d): GLMT simulation vs. measured force across the lateral axis.
# Simulated force = column 3 of simData / 10, scaled by 0.86 and by 125/100.
# The printed maximum below omits the 0.86 factor.
print("Maximum sim data is:{}".format(np.amax(simData[:,3]/10)*125/100))
# Loss through the sample's coverslip glass.
# NOTE(review): a previous comment here quoted 0.9 — confirm 0.86 is intended.
coverslip_transmission = 0.86
ax3.plot(simData[:, 0], (simData[:, 3]/10)*coverslip_transmission*125/100, label='GLMT', zorder=0)
ax3.errorbar(MeaData_x, MeaData_force, yerr=MeaData_delta, label='Measure', capsize=3.0, fmt='o', elinewidth=1.0, mfc='white', zorder=1, markersize=5, facecolor=None)

ax3.set_yticks([0, 1.5, 3])
ax3.yaxis.set_minor_locator(AutoMinorLocator(2))
ax3.set_xlim([-80, 80])
ax3.set_ylim([-0.1, 3.5])
ax3.legend(loc=8, fontsize=6, markerscale=0.5, labelspacing=0.2, columnspacing=1.0)
ax3.set_ylabel(r'Force (pN)', labelpad=0)
ax3.set_xlabel(r'Lateral axis ($\mu$m)', labelpad=-3)
ax3.text(-76, 3.5*0.76, r'(d)', fontsize=16)



plt.show()